{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "colab_type": "text",
    "id": "view-in-github"
   },
   "source": [
    "<a href=\"https://colab.research.google.com/github/Ankur3107/GitHub-Bugs-Prediction-Challenge/blob/main/model-template/MachineHack_roberta_base.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 595
    },
    "id": "nUbs9bOCnBTN",
    "outputId": "144fbf23-bec5-4ba8-a5fc-ae649d47270d"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Collecting transformers\n",
      "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/19/22/aff234f4a841f8999e68a7a94bdd4b60b4cebcfeca5d67d61cd08c9179de/transformers-3.3.1-py3-none-any.whl (1.1MB)\n",
      "\u001b[K     |████████████████████████████████| 1.1MB 3.4MB/s \n",
      "\u001b[?25hRequirement already satisfied: numpy in /usr/local/lib/python3.6/dist-packages (from transformers) (1.18.5)\n",
      "Requirement already satisfied: filelock in /usr/local/lib/python3.6/dist-packages (from transformers) (3.0.12)\n",
      "Requirement already satisfied: dataclasses; python_version < \"3.7\" in /usr/local/lib/python3.6/dist-packages (from transformers) (0.7)\n",
      "Requirement already satisfied: requests in /usr/local/lib/python3.6/dist-packages (from transformers) (2.23.0)\n",
      "Requirement already satisfied: tqdm>=4.27 in /usr/local/lib/python3.6/dist-packages (from transformers) (4.41.1)\n",
      "Collecting sentencepiece!=0.1.92\n",
      "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/d4/a4/d0a884c4300004a78cca907a6ff9a5e9fe4f090f5d95ab341c53d28cbc58/sentencepiece-0.1.91-cp36-cp36m-manylinux1_x86_64.whl (1.1MB)\n",
      "\u001b[K     |████████████████████████████████| 1.1MB 17.5MB/s \n",
      "\u001b[?25hRequirement already satisfied: regex!=2019.12.17 in /usr/local/lib/python3.6/dist-packages (from transformers) (2019.12.20)\n",
      "Collecting tokenizers==0.8.1.rc2\n",
      "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/80/83/8b9fccb9e48eeb575ee19179e2bdde0ee9a1904f97de5f02d19016b8804f/tokenizers-0.8.1rc2-cp36-cp36m-manylinux1_x86_64.whl (3.0MB)\n",
      "\u001b[K     |████████████████████████████████| 3.0MB 27.2MB/s \n",
      "\u001b[?25hCollecting sacremoses\n",
      "\u001b[?25l  Downloading https://files.pythonhosted.org/packages/7d/34/09d19aff26edcc8eb2a01bed8e98f13a1537005d31e95233fd48216eed10/sacremoses-0.0.43.tar.gz (883kB)\n",
      "\u001b[K     |████████████████████████████████| 890kB 26.1MB/s \n",
      "\u001b[?25hRequirement already satisfied: packaging in /usr/local/lib/python3.6/dist-packages (from transformers) (20.4)\n",
      "Requirement already satisfied: idna<3,>=2.5 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2.10)\n",
      "Requirement already satisfied: chardet<4,>=3.0.2 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (3.0.4)\n",
      "Requirement already satisfied: urllib3!=1.25.0,!=1.25.1,<1.26,>=1.21.1 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (1.24.3)\n",
      "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.6/dist-packages (from requests->transformers) (2020.6.20)\n",
      "Requirement already satisfied: six in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (1.15.0)\n",
      "Requirement already satisfied: click in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (7.1.2)\n",
      "Requirement already satisfied: joblib in /usr/local/lib/python3.6/dist-packages (from sacremoses->transformers) (0.16.0)\n",
      "Requirement already satisfied: pyparsing>=2.0.2 in /usr/local/lib/python3.6/dist-packages (from packaging->transformers) (2.4.7)\n",
      "Building wheels for collected packages: sacremoses\n",
      "  Building wheel for sacremoses (setup.py) ... \u001b[?25l\u001b[?25hdone\n",
      "  Created wheel for sacremoses: filename=sacremoses-0.0.43-cp36-none-any.whl size=893257 sha256=2e9363558ea8887396422c1b90e23e7cced22406d1ce44ff82dd836a65f0bca3\n",
      "  Stored in directory: /root/.cache/pip/wheels/29/3c/fd/7ce5c3f0666dab31a50123635e6fb5e19ceb42ce38d4e58f45\n",
      "Successfully built sacremoses\n",
      "Installing collected packages: sentencepiece, tokenizers, sacremoses, transformers\n",
      "Successfully installed sacremoses-0.0.43 sentencepiece-0.1.91 tokenizers-0.8.1rc2 transformers-3.3.1\n"
     ]
    }
   ],
   "source": [
    "!pip install transformers"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 374
    },
    "id": "HlVyn8uLoxED",
    "outputId": "ffb04be8-bf85-41b1-be69-f15d4fb2cd6a"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "--2020-10-15 12:36:30--  https://machinehack-be.s3.amazonaws.com/predict_github_issues_embold_sponsored_hackathon/Embold_Participant%27s_Dataset.zip\n",
      "Resolving machinehack-be.s3.amazonaws.com (machinehack-be.s3.amazonaws.com)... 52.219.62.68\n",
      "Connecting to machinehack-be.s3.amazonaws.com (machinehack-be.s3.amazonaws.com)|52.219.62.68|:443... connected.\n",
      "HTTP request sent, awaiting response... 200 OK\n",
      "Length: 102320961 (98M) [application/octet-stream]\n",
      "Saving to: ‘Embold_Participant's_Dataset.zip’\n",
      "\n",
      "Embold_Participant' 100%[===================>]  97.58M  12.3MB/s    in 9.5s    \n",
      "\n",
      "2020-10-15 12:36:41 (10.3 MB/s) - ‘Embold_Participant's_Dataset.zip’ saved [102320961/102320961]\n",
      "\n",
      "Archive:  ./Embold_Participant's_Dataset.zip\n",
      "   creating: Dataset/Embold_Participant's_Dataset/\n",
      "  inflating: Dataset/Embold_Participant's_Dataset/sample submission.csv  \n",
      "  inflating: Dataset/__MACOSX/Embold_Participant's_Dataset/._sample submission.csv  \n",
      "  inflating: Dataset/Embold_Participant's_Dataset/embold_train_extra.json  \n",
      "  inflating: Dataset/__MACOSX/Embold_Participant's_Dataset/._embold_train_extra.json  \n",
      "  inflating: Dataset/Embold_Participant's_Dataset/embold_test.json  \n",
      "  inflating: Dataset/__MACOSX/Embold_Participant's_Dataset/._embold_test.json  \n",
      "  inflating: Dataset/Embold_Participant's_Dataset/embold_train.json  \n",
      "  inflating: Dataset/__MACOSX/Embold_Participant's_Dataset/._embold_train.json  \n"
     ]
    }
   ],
   "source": [
    "!wget https://machinehack-be.s3.amazonaws.com/predict_github_issues_embold_sponsored_hackathon/Embold_Participant%27s_Dataset.zip\n",
    "!unzip ./Embold_Participant\\'s_Dataset.zip -d Dataset"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34
    },
    "id": "SKNytrlboyvX",
    "outputId": "7ac69944-c5c8-4ce2-9f55-645fd545cadc"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "/content/Dataset/Embold_Participant's_Dataset\n"
     ]
    }
   ],
   "source": [
    "cd \"Dataset/Embold_Participant's_Dataset/\""
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "54jZSwUio4wf"
   },
   "outputs": [],
   "source": [
    "#import pandas as pd\n",
    "import numpy as np\n",
    "import pandas as pd"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XOk2fr4qo6lF"
   },
   "outputs": [],
   "source": [
    "train_df = pd.read_json(\"embold_train.json\").reset_index(drop=True)\n",
    "test_df = pd.read_json(\"embold_test.json\").reset_index(drop=True)\n",
    "train_ex_df = pd.read_json(\"embold_train_extra.json\")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "Z_oDZymuo75N"
   },
   "outputs": [],
   "source": [
    "test_label_df = pd.read_csv('/content/ensemble_v8_1.csv')\n",
    "test_df['label'] = test_label_df['label']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ucvTg3mlo_Yy"
   },
   "outputs": [],
   "source": [
    "train_data = train_df.append(train_ex_df)\n",
    "test_df['text'] = test_df['title']+' '+test_df['body']\n",
    "train_data['text'] = train_data['title']+' '+train_data['body']"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "XM_M2kIzpA_Q"
   },
   "outputs": [],
   "source": [
    "from dataclasses import dataclass\n",
    "import numpy as np\n",
    "import pandas as pd\n",
    "import tensorflow as tf\n",
    "from tensorflow.keras.optimizers import Adam\n",
    "from transformers import *\n",
    "from transformers import pipeline\n",
    "from pprint import pprint\n",
    "from tensorflow import keras"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 765
    },
    "id": "7eRRqZdzpCe-",
    "outputId": "365fdfd1-ebed-423d-c802-8e9e9f87528e"
   },
   "outputs": [
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:absl:Entering into master device scope: /job:worker/replica:0/task:0/device:CPU:0\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Running on TPU  grpc://10.17.198.186:8470\n",
      "INFO:tensorflow:Initializing the TPU system: grpc://10.17.198.186:8470\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Initializing the TPU system: grpc://10.17.198.186:8470\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Clearing out eager caches\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Clearing out eager caches\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Finished initializing TPU system.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Finished initializing TPU system.\n",
      "WARNING:absl:`tf.distribute.experimental.TPUStrategy` is deprecated, please use  the non experimental symbol `tf.distribute.TPUStrategy` instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Found TPU system:\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:Found TPU system:\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Num TPU Cores: 8\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Num TPU Cores: 8\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Num TPU Workers: 1\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Num TPU Workers: 1\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Num TPU Cores Per Worker: 8\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Num TPU Cores Per Worker: 8\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:localhost/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:CPU:0, CPU, 0, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:0, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:1, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:2, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:3, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:4, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:5, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:6, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU:7, TPU, 0, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:TPU_SYSTEM:0, TPU_SYSTEM, 0, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "INFO:tensorflow:*** Available Device: _DeviceAttributes(/job:worker/replica:0/task:0/device:XLA_CPU:0, XLA_CPU, 0, 0)\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "REPLICAS:  8\n"
     ]
    }
   ],
   "source": [
    "@dataclass\n",
    "class Config:\n",
    "    MAX_LEN = 320\n",
    "    BATCH_SIZE = 16  # per TPU core\n",
    "    TOTAL_STEPS = 2000  # thats approx 4 epochs\n",
    "    EVALUATE_EVERY = 200\n",
    "    LR = 1e-5\n",
    "    PRETRAINED_MODEL = \"roberta-large\"  # huggingface bert model\n",
    "\n",
    "\n",
    "transformer_flags = Config()\n",
    "\n",
    "\n",
    "def connect_to_TPU():\n",
    "    \"\"\"Detect hardware, return appropriate distribution strategy\"\"\"\n",
    "    try:\n",
    "        # TPU detection. No parameters necessary if TPU_NAME environment variable is\n",
    "        # set: this is always the case on Kaggle.\n",
    "        tpu = tf.distribute.cluster_resolver.TPUClusterResolver()\n",
    "        print(\"Running on TPU \", tpu.master())\n",
    "    except ValueError:\n",
    "        tpu = None\n",
    "\n",
    "    if tpu:\n",
    "        tf.config.experimental_connect_to_cluster(tpu)\n",
    "        tf.tpu.experimental.initialize_tpu_system(tpu)\n",
    "        strategy = tf.distribute.experimental.TPUStrategy(tpu)\n",
    "    else:\n",
    "        # Default distribution strategy in Tensorflow. Works on CPU and single GPU.\n",
    "        strategy = tf.distribute.get_strategy()\n",
    "\n",
    "    global_batch_size = transformer_flags.BATCH_SIZE * strategy.num_replicas_in_sync\n",
    "\n",
    "    return tpu, strategy, global_batch_size\n",
    "\n",
    "\n",
    "tpu, strategy, global_batch_size = connect_to_TPU()\n",
    "print(\"REPLICAS: \", strategy.num_replicas_in_sync)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9z3t70bBpH_e"
   },
   "outputs": [],
   "source": [
    "def regular_encode(texts, tokenizer, maxlen=512):\n",
    "    enc_di = tokenizer.batch_encode_plus(\n",
    "        texts,\n",
    "        return_attention_mask=False,\n",
    "        return_token_type_ids=False,\n",
    "        pad_to_max_length=True,\n",
    "        max_length=maxlen,\n",
    "        truncation=True,\n",
    "    )\n",
    "\n",
    "    return np.array(enc_di[\"input_ids\"])\n",
    "\n",
    "from transformers import *\n",
    "#tokenizer = AutoTokenizer.from_pretrained(transformer_flags.PRETRAINED_MODEL)\n",
    "#tokenizer = BertTokenizerFast.from_pretrained(transformer_flags.PRETRAINED_MODEL)\n",
    "#tokenizer = AlbertTokenizer.from_pretrained(transformer_flags.PRETRAINED_MODEL, use_fast=True)\n",
    "tokenizer = RobertaTokenizerFast.from_pretrained(transformer_flags.PRETRAINED_MODEL)\n",
    "#tokenizer = ElectraTokenizerFast.from_pretrained(transformer_flags.PRETRAINED_MODEL)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "9r2RlmnspMg5"
   },
   "outputs": [],
   "source": [
    "def build_model():\n",
    "  input_layer = tf.keras.layers.Input((transformer_flags.MAX_LEN,), dtype=tf.int32)\n",
    "  bert_model = TFAutoModel.from_pretrained(transformer_flags.PRETRAINED_MODEL, from_pt=True)\n",
    "  #bert_model = TFAutoModel.from_pretrained('/content/github-albert-xlarge-v2/')\n",
    "  #bert_model = TFAutoModel.from_pretrained('../../github-roberta-large/')\n",
    "  output_layer = bert_model(input_layer) \n",
    "  pooled_output = tf.keras.layers.GlobalAveragePooling1D()(output_layer[0])\n",
    "  output = tf.keras.layers.Dense(3, activation='softmax')(pooled_output)\n",
    "  classifier_model = tf.keras.Model(input_layer, output)\n",
    "\n",
    "  optimizer = tf.keras.optimizers.Adam(learning_rate=transformer_flags.LR)\n",
    "  classifier_model.compile(optimizer=optimizer, loss='sparse_categorical_crossentropy', metrics=['accuracy'])\n",
    "  classifier_model.summary()\n",
    "    \n",
    "  return classifier_model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "p0vjS-AypQ8M"
   },
   "outputs": [],
   "source": [
    "all_train = train_data[0:200000]\n",
    "all_train = all_train.append(train_data[230000:])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "MCIH6njxpSqP"
   },
   "outputs": [],
   "source": [
    "question_df = all_train[all_train.label==2]\n",
    "bug_issue = all_train[all_train.label!=2]\n",
    "bug_issue = bug_issue.sample(50000)\n",
    "train = question_df.append(bug_issue)\n",
    "train = train.append(test_df)\n",
    "train = train.sample(len(train))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 85
    },
    "id": "hnGwC4FepT9Y",
    "outputId": "4a883dfb-636b-4461-a77b-77ac3bae7997"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "2    41662\n",
       "1    39361\n",
       "0    38302\n",
       "Name: label, dtype: int64"
      ]
     },
     "execution_count": 12,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "train.label.value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 85
    },
    "id": "ZnG4Oa38pV3m",
    "outputId": "27e614ed-a8e9-42c1-9c1f-5aa0bc56c434"
   },
   "outputs": [
    {
     "data": {
      "text/plain": [
       "1    13846\n",
       "0    13278\n",
       "2     2876\n",
       "Name: label, dtype: int64"
      ]
     },
     "execution_count": 13,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "validate = train_data[200000:230000]\n",
    "validate.label.value_counts()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "i_d18azRpXCx"
   },
   "outputs": [],
   "source": [
    "with strategy.scope():\n",
    "    classifier_model = build_model()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "nZN7-MFqpYfi"
   },
   "outputs": [],
   "source": [
    "# data preparation\n",
    "X_data = regular_encode(train.text.values.tolist(), tokenizer, maxlen=transformer_flags.MAX_LEN)\n",
    "y_train = train.label.values\n",
    "print(X_data.shape, y_train.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "ii5O_lOLpfkt"
   },
   "outputs": [],
   "source": [
    "# data preparation\n",
    "X_val = regular_encode(validate.text.values.tolist(), tokenizer, maxlen=transformer_flags.MAX_LEN)\n",
    "y_val = validate.label.values\n",
    "print(X_val.shape, y_val.shape)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "0KUIo4-rpf_h"
   },
   "outputs": [],
   "source": [
    "X_test = regular_encode(test_df.text.values.tolist(), tokenizer, maxlen=transformer_flags.MAX_LEN)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 428
    },
    "id": "YFs4hlpXphq6",
    "outputId": "d2c84bb3-7345-47ec-f2ba-16d565709598"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Epoch 1/2\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.data.Iterator.get_next_as_optional()` instead.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/data/ops/multi_device_iterator_ops.py:601: get_next_as_optional (from tensorflow.python.data.ops.iterator_ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Use `tf.data.Iterator.get_next_as_optional()` instead.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Gradients do not exist for variables ['tf_roberta_model/roberta/pooler/dense/kernel:0', 'tf_roberta_model/roberta/pooler/dense/bias:0'] when minimizing the loss.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Gradients do not exist for variables ['tf_roberta_model/roberta/pooler/dense/kernel:0', 'tf_roberta_model/roberta/pooler/dense/bias:0'] when minimizing the loss.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Gradients do not exist for variables ['tf_roberta_model/roberta/pooler/dense/kernel:0', 'tf_roberta_model/roberta/pooler/dense/bias:0'] when minimizing the loss.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Gradients do not exist for variables ['tf_roberta_model/roberta/pooler/dense/kernel:0', 'tf_roberta_model/roberta/pooler/dense/bias:0'] when minimizing the loss.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Gradients do not exist for variables ['tf_roberta_model/roberta/pooler/dense/kernel:0', 'tf_roberta_model/roberta/pooler/dense/bias:0'] when minimizing the loss.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Gradients do not exist for variables ['tf_roberta_model/roberta/pooler/dense/kernel:0', 'tf_roberta_model/roberta/pooler/dense/bias:0'] when minimizing the loss.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Gradients do not exist for variables ['tf_roberta_model/roberta/pooler/dense/kernel:0', 'tf_roberta_model/roberta/pooler/dense/bias:0'] when minimizing the loss.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Gradients do not exist for variables ['tf_roberta_model/roberta/pooler/dense/kernel:0', 'tf_roberta_model/roberta/pooler/dense/bias:0'] when minimizing the loss.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   2/3427 [..............................] - ETA: 17:51:24 - loss: 1.1823 - accuracy: 0.2969WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0115s vs `on_train_batch_end` time: 0.1867s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0115s vs `on_train_batch_end` time: 0.1867s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3427/3427 [==============================] - ETA: 0s - loss: 0.5101 - accuracy: 0.8049WARNING:tensorflow:Callbacks method `on_test_batch_end` is slow compared to the batch time (batch time: 0.0039s vs `on_test_batch_end` time: 0.0545s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Callbacks method `on_test_batch_end` is slow compared to the batch time (batch time: 0.0039s vs `on_test_batch_end` time: 0.0545s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
      "3427/3427 [==============================] - 763s 223ms/step - loss: 0.5101 - accuracy: 0.8049 - val_loss: 0.4532 - val_accuracy: 0.8303\n",
      "Epoch 2/2\n",
      "3427/3427 [==============================] - 710s 207ms/step - loss: 0.4480 - accuracy: 0.8330 - val_loss: 0.4552 - val_accuracy: 0.8275\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7fa8a58216a0>"
      ]
     },
     "execution_count": 23,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# 0:200000+30000 datapoint\n",
    "classifier_model.fit(X_data, y_train, epochs=1,batch_size=64, validation_data=(X_val, y_val))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 139
    },
    "id": "dflGgwL_zy2k",
    "outputId": "23b05167-d1e7-4119-8ab5-19dbd7768de2"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   2/3427 [..............................] - ETA: 11:01 - loss: 0.4318 - accuracy: 0.8828WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0090s vs `on_train_batch_end` time: 0.1855s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0090s vs `on_train_batch_end` time: 0.1855s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "3427/3427 [==============================] - ETA: 0s - loss: 0.4339 - accuracy: 0.8391WARNING:tensorflow:Callbacks method `on_test_batch_end` is slow compared to the batch time (batch time: 0.0044s vs `on_test_batch_end` time: 0.0520s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Callbacks method `on_test_batch_end` is slow compared to the batch time (batch time: 0.0044s vs `on_test_batch_end` time: 0.0520s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
      "3427/3427 [==============================] - 713s 208ms/step - loss: 0.4339 - accuracy: 0.8391 - val_loss: 0.4428 - val_accuracy: 0.8331\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7fa8a1953518>"
      ]
     },
     "execution_count": 35,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Train for a single epoch on the current (X_data, y_train) arrays and score\n",
    "# (X_val, y_val) once the epoch finishes. The History object returned by fit()\n",
    "# is the cell's displayed value.\n",
    "# NOTE(review): the original note read '0:200000+30000 datapoint' - presumably the\n",
    "# row ranges behind the train/validation split. The saved output logs 3427 steps at\n",
    "# batch_size=64 (roughly 219K rows); confirm against the data-preparation cells.\n",
    "# NOTE(review): batch_size is the literal 64 here, while the prediction cell uses\n",
    "# `global_batch_size`; confirm the difference is intentional.\n",
    "classifier_model.fit(\n",
    "    x=X_data,\n",
    "    y=y_train,\n",
    "    validation_data=(X_val, y_val),\n",
    "    batch_size=64,\n",
    "    epochs=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 139
    },
    "id": "mYLpG9W33dXs",
    "outputId": "c3d38a45-f13d-4a86-bcc2-a1fe90751019"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   2/2646 [..............................] - ETA: 8:59 - loss: 0.5107 - accuracy: 0.7812WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0108s vs `on_train_batch_end` time: 0.1876s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0108s vs `on_train_batch_end` time: 0.1876s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "2646/2646 [==============================] - ETA: 0s - loss: 0.4202 - accuracy: 0.8443WARNING:tensorflow:Callbacks method `on_test_batch_end` is slow compared to the batch time (batch time: 0.0045s vs `on_test_batch_end` time: 0.0534s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Callbacks method `on_test_batch_end` is slow compared to the batch time (batch time: 0.0045s vs `on_test_batch_end` time: 0.0534s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
      "2646/2646 [==============================] - 593s 224ms/step - loss: 0.4202 - accuracy: 0.8443 - val_loss: 0.4492 - val_accuracy: 0.8311\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7fa8a19c5748>"
      ]
     },
     "execution_count": 43,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# One further training epoch over whatever (X_data, y_train) currently hold,\n",
    "# with (X_val, y_val) evaluated at the end of the epoch.\n",
    "# NOTE(review): the inherited note '0:200000+30000 datapoint' looks stale for this\n",
    "# run - the saved output logs 2646 steps at batch_size=64 (roughly 169K rows).\n",
    "# Confirm which slice of the data this epoch was meant to use.\n",
    "classifier_model.fit(\n",
    "    x=X_data,\n",
    "    y=y_train,\n",
    "    validation_data=(X_val, y_val),\n",
    "    batch_size=64,\n",
    "    epochs=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 139
    },
    "id": "m27rRkwA623Q",
    "outputId": "f8d84981-af83-46f6-8fdd-35956fc8e436"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   2/1865 [..............................] - ETA: 5:58 - loss: 0.4215 - accuracy: 0.8516WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0105s vs `on_train_batch_end` time: 0.1836s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0105s vs `on_train_batch_end` time: 0.1836s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1865/1865 [==============================] - ETA: 0s - loss: 0.3891 - accuracy: 0.8549WARNING:tensorflow:Callbacks method `on_test_batch_end` is slow compared to the batch time (batch time: 0.0042s vs `on_test_batch_end` time: 0.0514s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Callbacks method `on_test_batch_end` is slow compared to the batch time (batch time: 0.0042s vs `on_test_batch_end` time: 0.0514s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
      "1865/1865 [==============================] - 434s 233ms/step - loss: 0.3891 - accuracy: 0.8549 - val_loss: 0.4763 - val_accuracy: 0.8216\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7fa8a19dca58>"
      ]
     },
     "execution_count": 50,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Continue training the same model for one epoch on the current\n",
    "# (X_data, y_train) arrays; (X_val, y_val) is scored when the epoch ends.\n",
    "# NOTE(review): the copied note '0:200000+30000 datapoint' does not appear to match\n",
    "# this run - the saved output logs 1865 steps at batch_size=64 (roughly 119K rows).\n",
    "# Verify the intended training subset.\n",
    "classifier_model.fit(\n",
    "    x=X_data,\n",
    "    y=y_train,\n",
    "    validation_data=(X_val, y_val),\n",
    "    batch_size=64,\n",
    "    epochs=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 139
    },
    "id": "1c6xKEHh9Tv9",
    "outputId": "0c618a38-7020-47b5-ee5d-8eba10351500"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "   2/1552 [..............................] - ETA: 5:30 - loss: 0.3634 - accuracy: 0.8594WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0115s vs `on_train_batch_end` time: 0.1851s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Callbacks method `on_train_batch_end` is slow compared to the batch time (batch time: 0.0115s vs `on_train_batch_end` time: 0.1851s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "1552/1552 [==============================] - ETA: 0s - loss: 0.3402 - accuracy: 0.8749WARNING:tensorflow:Callbacks method `on_test_batch_end` is slow compared to the batch time (batch time: 0.0043s vs `on_test_batch_end` time: 0.0518s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Callbacks method `on_test_batch_end` is slow compared to the batch time (batch time: 0.0043s vs `on_test_batch_end` time: 0.0518s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\b\r",
      "1552/1552 [==============================] - 338s 218ms/step - loss: 0.3402 - accuracy: 0.8749 - val_loss: 0.5947 - val_accuracy: 0.7892\n"
     ]
    },
    {
     "data": {
      "text/plain": [
       "<tensorflow.python.keras.callbacks.History at 0x7fa89ffc31d0>"
      ]
     },
     "execution_count": 58,
     "metadata": {
      "tags": []
     },
     "output_type": "execute_result"
    }
   ],
   "source": [
    "# Last single-epoch training pass before prediction: fit on the current\n",
    "# (X_data, y_train) arrays and evaluate (X_val, y_val) at the end of the epoch.\n",
    "# NOTE(review): the copied note '0:200000+30000 datapoint' looks stale - the saved\n",
    "# output logs 1552 steps at batch_size=64 (roughly 99K rows). Confirm the subset.\n",
    "# NOTE(review): in the saved run the training loss is 0.3402 but val_loss rises to\n",
    "# 0.5947; consider whether the model from this epoch should be used for prediction.\n",
    "classifier_model.fit(\n",
    "    x=X_data,\n",
    "    y=y_train,\n",
    "    validation_data=(X_val, y_val),\n",
    "    batch_size=64,\n",
    "    epochs=1,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 139
    },
    "id": "tzTT8FpGpjSI",
    "outputId": "7e9979ad-6723-4e0b-ed35-efa186525193"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "  2/235 [..............................] - ETA: 10sWARNING:tensorflow:Callbacks method `on_predict_batch_end` is slow compared to the batch time (batch time: 0.0039s vs `on_predict_batch_end` time: 0.0872s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Callbacks method `on_predict_batch_end` is slow compared to the batch time (batch time: 0.0039s vs `on_predict_batch_end` time: 0.0872s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "235/235 [==============================] - 21s 90ms/step\n",
      "  2/235 [..............................] - ETA: 10sWARNING:tensorflow:Callbacks method `on_predict_batch_end` is slow compared to the batch time (batch time: 0.0039s vs `on_predict_batch_end` time: 0.0867s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stderr",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:Callbacks method `on_predict_batch_end` is slow compared to the batch time (batch time: 0.0039s vs `on_predict_batch_end` time: 0.0867s). Check your callbacks.\n"
     ]
    },
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "235/235 [==============================] - 21s 91ms/step\n"
     ]
    }
   ],
   "source": [
    "# Score the validation inputs first, then the test inputs, with the fitted model.\n",
    "# Both calls batch by `global_batch_size` and show a progress bar (verbose=1).\n",
    "y_val_pred = classifier_model.predict(\n",
    "    X_val,\n",
    "    verbose=1,\n",
    "    batch_size=global_batch_size,\n",
    ")\n",
    "y_test_pred = classifier_model.predict(\n",
    "    X_test,\n",
    "    verbose=1,\n",
    "    batch_size=global_batch_size,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "0N1aPu_Yp2R4"
   },
   "outputs": [],
   "source": [
    "# Pair each validation prediction with its label in one frame.\n",
    "# Column 'X' holds the model outputs (`y_val_pred`), column 'y' the values of `y_val`;\n",
    "# both are converted with .tolist() before being handed to pandas.\n",
    "eval_df = pd.DataFrame(\n",
    "    {\n",
    "        'X': y_val_pred.tolist(),\n",
    "        'y': y_val.tolist(),\n",
    "    }\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "aRjTwEXhp5eV"
   },
   "outputs": [],
   "source": [
    "# Persist the validation predictions/labels frame as CSV, omitting the row index.\n",
    "eval_df.to_csv(\n",
    "    path_or_buf='textattack-roberta-base-MNLI-maxlen320-PL-6epochs-90K87K40K-eval-prob.csv',\n",
    "    index=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "A6RB3WLyqBfA"
   },
   "outputs": [],
   "source": [
    "# Start from the provided sample-submission file and overwrite/add its 'label'\n",
    "# column with the test-set model outputs (converted via .tolist()).\n",
    "# NOTE(review): the '-prob' suffix suggests these are raw scores rather than hard\n",
    "# class ids - confirm this is the format expected downstream.\n",
    "submission = pd.read_csv('sample submission.csv')\n",
    "submission['label'] = y_test_pred.tolist()\n",
    "\n",
    "# Write the filled-in frame next to the notebook, without the row index.\n",
    "file_name = 'textattack-roberta-base-MNLI-maxlen320-PL-6epochs-90K87K40K-test-prob.csv'\n",
    "submission.to_csv(\n",
    "    path_or_buf=file_name,\n",
    "    index=False,\n",
    ")"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {
    "id": "cgSW1-12qFKM"
   },
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "accelerator": "TPU",
  "colab": {
   "authorship_tag": "ABX9TyPplBqoGwIaeSt+ggBogrvi",
   "collapsed_sections": [],
   "include_colab_link": true,
   "name": "MachineHack roberta-base.ipynb",
   "provenance": []
  },
  "kernelspec": {
   "display_name": "Python 3",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.7.5"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 1
}
